Introduction¶
Every machine learning task involves two fundamental components:
Optimization: The process of tuning a model to achieve the best possible performance on the training data, which is the essence of learning in machine learning.
Generalization: How well the trained model performs on new data that it has never seen before.
The main challenge in machine learning is to find a balance between optimization and generalization. The ultimate goal is to achieve good generalization, but this cannot be directly controlled. Instead, the model can only be adjusted based on its training data to increase its chances of generalizing well.
Overfitting occurs when a model is so well optimized that it fits the training dataset too closely and therefore fails to generalize. Overfitting arises in every machine-learning problem, and learning how to deal with it is essential to mastering machine learning.
Overfitting¶
Let's consider a simple polynomial fitting example to better illustrate overfitting.
Suppose we have observed 13 data points (called training data because we use these data to train our forecasting model) that are generated from the function $\sin(2 \pi x)$ with some small perturbations. The figure below shows these training data as blue circles together with the curve of $\sin(2 \pi x)$.
Our task is to recover the underlying data pattern by using simple polynomial functions of degrees 1, 3, and 9. In other words, without knowing the true function $\sin(2 \pi x)$, we need to
- fit the given 13 training data points with simple polynomial functions,
- choose one of the polynomial functions to better represent the true function $\sin(2 \pi x)$ for forecasting future observations.
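The two steps above can be sketched in a few lines of numpy. This is a minimal illustration, not the exact data from the figures: the noise level 0.15 and the random seed are assumptions made for reproducibility.

```python
import numpy as np

rng = np.random.default_rng(0)

# 13 training points generated from sin(2*pi*x) with small Gaussian
# perturbations (noise standard deviation 0.15 is an illustrative choice)
x_train = np.linspace(0, 1, 13)
y_train = np.sin(2 * np.pi * x_train) + rng.normal(0, 0.15, size=x_train.shape)

# Fit polynomials of degree 1, 3 and 9 by least squares
fits = {d: np.polynomial.Polynomial.fit(x_train, y_train, deg=d) for d in (1, 3, 9)}

# Training error (mean squared error) for each degree
for d, p in fits.items():
    mse = np.mean((p(x_train) - y_train) ** 2)
    print(f"degree {d}: training MSE = {mse:.4f}")
```

Note that the training error alone always favors the higher-degree polynomial, since a degree-9 polynomial can bend itself toward every noisy observation; this is exactly why training error cannot be used to choose the model.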
If the numerical example is hard to relate to, consider the following business application.
The green $\sin(2 \pi x)$ curve is the true underlying business cycle (which has an up-and-down pattern), and the blue circles are observed sales. In an ideal world, sales would be fully described by the business cycle. In the real world, however, sales may be perturbed from the true business cycle by other random factors (such as weather). That is why the blue circles are generated from the $\sin(2 \pi x)$ curve with small perturbations. Our goal is to find a model (i.e., a polynomial function) that properly explains the underlying business cycle.
The figure below shows the fitted polynomials of degrees 1, 3 and 9, drawn as red curves (or a line) in the subgraphs.
Note that a simple polynomial function with degree 1 is a linear function. Hence, the fitting procedure with a linear function is the well-known linear regression.
If we knew the underlying true function $\sin(2 \pi x)$, then it would be quite clear that
- the linear function is under-fitting because it doesn't explain the up-and-down pattern in the training data;
- the polynomial function with degree 9 is over-fitting. Although the polynomial curve accurately describes the observed training data, it focuses too much on noise and does not correctly capture the pattern of the underlying function $\sin(2 \pi x)$;
- the polynomial with degree 3 is the best because it balances signal against noise and extracts the valuable information from the training data.
However, as in many real-world applications, we do not know the underlying true pattern of the training data, i.e., the function $\sin(2 \pi x)$ in our example. We can only observe a figure like the one below, which poses a simple question to all decision makers: which polynomial best represents the data pattern for future forecasting?
Clearly, we can rule out linear regression, since it does a poor job of explaining the training data. Although the 9-degree polynomial perfectly fits the training data, it performs horribly on the testing data: the Y values corresponding to the 3 testing data points actually fall off the chart for the 9-degree polynomial. Hence, the 3-degree polynomial is the obvious winner based on the testing data.
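This comparison can be reproduced numerically by evaluating each fitted polynomial on held-out points. The three testing locations and the noise level below are illustrative assumptions, not the exact values behind the figure.

```python
import numpy as np

rng = np.random.default_rng(1)

# Same setup as before: 13 noisy training points from sin(2*pi*x)
x_train = np.linspace(0, 1, 13)
y_train = np.sin(2 * np.pi * x_train) + rng.normal(0, 0.15, size=13)

# 3 held-out testing points, drawn from the same underlying process
# (the locations 0.12, 0.47, 0.83 are arbitrary illustrative choices)
x_test = np.array([0.12, 0.47, 0.83])
y_test = np.sin(2 * np.pi * x_test) + rng.normal(0, 0.15, size=3)

test_mse = {}
for d in (1, 3, 9):
    p = np.polynomial.Polynomial.fit(x_train, y_train, deg=d)
    test_mse[d] = float(np.mean((p(x_test) - y_test) ** 2))
    print(f"degree {d}: test MSE = {test_mse[d]:.3f}")
```

Unlike the training error, the testing error penalizes the under-fitting linear model and exposes the over-fitting degree-9 model, which is why model selection must be based on held-out data.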
The example illustrates the law of parsimony (Occam's razor): we generally prefer a simpler model because simpler models track only the most basic underlying patterns and can be more flexible and accurate in predicting the future!
This polynomial fitting example demonstrates the use of a validation dataset for model selection. The best practice is to use K-fold cross-validation, which will be introduced in the section on predicting house prices.
To prevent a model from learning irrelevant or misleading patterns found in the training data, the best solution is to obtain more training data. A model trained on more data will naturally generalize better.
If obtaining more data is not possible, the next best solution is to limit the amount of information the model can store. By forcing the model to focus on the most prominent patterns, it has a better chance of generalizing well.
The process of mitigating overfitting by reducing the model's ability to learn irrelevant patterns is called regularization. The most commonly used regularization techniques for neural networks are:
- Reducing the network's size
- Adding weight regularization
- Adding dropout
We will discuss and demonstrate the application of these techniques in the later sections.
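Weight regularization can already be previewed on the polynomial example. The sketch below adds an L2 penalty to the degree-9 least-squares fit (ridge regression in numpy rather than a neural network); the helper name `ridge_poly_fit` and the penalty strength `lam=1e-3` are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(0, 1, 13)
y = np.sin(2 * np.pi * x) + rng.normal(0, 0.15, size=13)

def ridge_poly_fit(x, y, degree, lam):
    """Polynomial least squares with an L2 penalty lam on the coefficients."""
    X = np.vander(x, degree + 1)            # columns x^degree, ..., x^1, 1
    A = X.T @ X + lam * np.eye(degree + 1)  # penalized normal equations
    return np.linalg.solve(A, X.T @ y)

w_plain = ridge_poly_fit(x, y, 9, lam=0.0)
w_ridge = ridge_poly_fit(x, y, 9, lam=1e-3)

# The penalty shrinks the large coefficients of the unregularized fit,
# limiting how violently the degree-9 curve can chase the noise
print("max |coef| without penalty:", np.abs(w_plain).max())
print("max |coef| with penalty:   ", np.abs(w_ridge).max())
```

The same idea carries over to neural networks: penalizing large weights limits the amount of information the model can store, which is exactly the capacity-reduction strategy described above.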